{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 一、定义前向、后向传播\n",
    "\n",
    "本文将用numpy实现cnn, 并测试mnist手写数字识别\n",
    "\n",
    "如果对神经网络的反向传播过程还有不清楚的，可以参考 [全连接层、损失函数的反向传播](0_1-全连接层、损失函数的反向传播.md)、[卷积层的反向传播](0_2_4-卷积层的反向传播-多通道、有padding、步长不为1.md)、[池化层的反向传播](0_2_5-池化层的反向传播-MaxPooling、AveragePooling、GlobalAveragePooling.md)\n",
    "\n",
    "网络结构如下,包括1个卷积层,1个最大池化层，1个打平层2个全连接层：\n",
    "\n",
    "input(1,28\\*28)=> conv(1,3,3) => relu => max pooling => flatten => fc(64) => relu => fc(10)\n",
    "\n",
    "这里定义卷积层只有一个输出通道,全连接层的神经元也只有64个神经元;主要是由于纯numpy的神经网络比较慢，本文主要目的理解神经网络的反向传播过程，以及如何用numpy实现神经网络，因此不追求设计最合适的网络结构;numpy实现的实现的神经网络无法应用到实际的项目中，请使用深度学习框架(如：Tensorflow、Keras、Caffe等)"
   ]
  },
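  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the aim of this notebook is to understand backpropagation, one way to gain confidence in a hand-written backward pass is a numerical gradient check. The sketch below is a minimal illustration on a toy fully connected layer. `toy_fc_forward` and `toy_fc_backward` are hypothetical stand-ins that only mirror the call pattern used later in this notebook; the actual `nn.layers` implementations may differ, for example in how they scale by batch size.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def toy_fc_forward(z, W, b):\n",
    "    # Fully connected layer: y = z W + b\n",
    "    return np.dot(z, W) + b\n",
    "\n",
    "def toy_fc_backward(dy, W, z):\n",
    "    # Gradients with respect to W, b and the layer input z\n",
    "    dW = np.dot(z.T, dy)\n",
    "    db = np.sum(dy, axis=0)\n",
    "    dz = np.dot(dy, W.T)\n",
    "    return dW, db, dz\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "z, W, b = rng.randn(4, 5), rng.randn(5, 3), rng.randn(3)\n",
    "dy = rng.randn(4, 3)  # stand-in for the upstream gradient\n",
    "dW, db, dz = toy_fc_backward(dy, W, z)\n",
    "\n",
    "# Scalar test loss whose gradient with respect to y is exactly dy\n",
    "loss = lambda Wx: np.sum(toy_fc_forward(z, Wx, b) * dy)\n",
    "\n",
    "# Central-difference estimate of d(loss) / dW[0, 0]\n",
    "eps = 1e-6\n",
    "W_plus, W_minus = W.copy(), W.copy()\n",
    "W_plus[0, 0] += eps\n",
    "W_minus[0, 0] -= eps\n",
    "num = (loss(W_plus) - loss(W_minus)) / (2 * eps)\n",
    "print(abs(num - dW[0, 0]) < 1e-6)\n",
    "```\n",
    "\n",
    "If the analytic and numerical values agree to within a small tolerance, the backward pass for that parameter is likely correct; the same idea can be applied entry by entry to the other layers."
   ]
  },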
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义权重、神经元、梯度\n",
    "import numpy as np\n",
    "weights = {}\n",
    "weights_scale = 1e-2\n",
    "filters = 1\n",
    "fc_units=64\n",
    "weights[\"K1\"] = weights_scale * np.random.randn(1, filters, 3, 3).astype(np.float64)\n",
    "weights[\"b1\"] = np.zeros(filters).astype(np.float64)\n",
    "weights[\"W2\"] = weights_scale * np.random.randn(filters * 13 * 13, fc_units).astype(np.float64)\n",
    "weights[\"b2\"] = np.zeros(fc_units).astype(np.float64)\n",
    "weights[\"W3\"] = weights_scale * np.random.randn(fc_units, 10).astype(np.float64)\n",
    "weights[\"b3\"] = np.zeros(10).astype(np.float64)\n",
    "\n",
    "# 初始化神经元和梯度\n",
    "nuerons={}\n",
    "gradients={}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 定义前向传播和反向传播\n",
    "from nn.layers import conv_backward,fc_forward,fc_backward\n",
    "from nn.layers import flatten_forward,flatten_backward\n",
    "from nn.activations import relu_forward,relu_backward\n",
    "from nn.losses import cross_entropy_loss\n",
    "\n",
    "import pyximport\n",
    "pyximport.install()\n",
    "from nn.clayers import conv_forward,max_pooling_forward,max_pooling_backward\n",
    "\n",
    "\n",
    "\n",
    "# 定义前向传播\n",
    "def forward(X):\n",
    "    nuerons[\"conv1\"]=conv_forward(X.astype(np.float64),weights[\"K1\"],weights[\"b1\"])\n",
    "    nuerons[\"conv1_relu\"]=relu_forward(nuerons[\"conv1\"])\n",
    "    nuerons[\"maxp1\"]=max_pooling_forward(nuerons[\"conv1_relu\"].astype(np.float64),pooling=(2,2))\n",
    "\n",
    "    nuerons[\"flatten\"]=flatten_forward(nuerons[\"maxp1\"])\n",
    "    \n",
    "    nuerons[\"fc2\"]=fc_forward(nuerons[\"flatten\"],weights[\"W2\"],weights[\"b2\"])\n",
    "    nuerons[\"fc2_relu\"]=relu_forward(nuerons[\"fc2\"])\n",
    "    \n",
    "    nuerons[\"y\"]=fc_forward(nuerons[\"fc2_relu\"],weights[\"W3\"],weights[\"b3\"])\n",
    "\n",
    "    return nuerons[\"y\"]\n",
    "\n",
    "# 定义反向传播\n",
    "def backward(X,y_true):\n",
    "    loss,dy=cross_entropy_loss(nuerons[\"y\"],y_true)\n",
    "    gradients[\"W3\"],gradients[\"b3\"],gradients[\"fc2_relu\"]=fc_backward(dy,weights[\"W3\"],nuerons[\"fc2_relu\"])\n",
    "    gradients[\"fc2\"]=relu_backward(gradients[\"fc2_relu\"],nuerons[\"fc2\"])\n",
    "    \n",
    "    gradients[\"W2\"],gradients[\"b2\"],gradients[\"flatten\"]=fc_backward(gradients[\"fc2\"],weights[\"W2\"],nuerons[\"flatten\"])\n",
    "    \n",
    "    gradients[\"maxp1\"]=flatten_backward(gradients[\"flatten\"],nuerons[\"maxp1\"])\n",
    "       \n",
    "    gradients[\"conv1_relu\"]=max_pooling_backward(gradients[\"maxp1\"].astype(np.float64),nuerons[\"conv1_relu\"].astype(np.float64),pooling=(2,2))\n",
    "    gradients[\"conv1\"]=relu_backward(gradients[\"conv1_relu\"],nuerons[\"conv1\"])\n",
    "    gradients[\"K1\"],gradients[\"b1\"],_=conv_backward(gradients[\"conv1\"],weights[\"K1\"],X)\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 获取精度\n",
    "def get_accuracy(X,y_true):\n",
    "    y_predict=forward(X)\n",
    "    return np.mean(np.equal(np.argmax(y_predict,axis=-1),\n",
    "                            np.argmax(y_true,axis=-1)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 二、加载数据\n",
    "\n",
    "mnist.pkl.gz数据源： http://deeplearning.net/data/mnist/mnist.pkl.gz   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from nn.load_mnist import load_mnist_datasets\n",
    "from nn.utils import to_categorical\n",
    "train_set, val_set, test_set = load_mnist_datasets('mnist.pkl.gz')\n",
    "train_x,val_x,test_x=np.reshape(train_set[0],(-1,1,28,28)),np.reshape(val_set[0],(-1,1,28,28)),np.reshape(test_set[0],(-1,1,28,28))\n",
    "train_y,val_y,test_y=to_categorical(train_set[1]),to_categorical(val_set[1]),to_categorical(test_set[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x.shape:(16, 1, 28, 28),y.shape:(16, 10)\n"
     ]
    }
   ],
   "source": [
    "# 随机选择训练样本\n",
    "train_num = train_x.shape[0]\n",
    "def next_batch(batch_size):\n",
    "    idx=np.random.choice(train_num,batch_size)\n",
    "    return train_x[idx],train_y[idx]\n",
    "\n",
    "x,y= next_batch(16)\n",
    "print(\"x.shape:{},y.shape:{}\".format(x.shape,y.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 三、训练网络\n",
    "\n",
    "由于numpy卷积层层前向、后向过程较慢,这里只迭代2000步,mini-batch设置为2;实际只训练了4000个样本(也有不错的精度,增加迭代次数精度会继续提升;增加卷积层输出通道数，精度上限也会提升);总样本有5w个。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " step:0 ; loss:2.3025710785961633\n",
      " train_acc:0.5;  val_acc:0.105\n",
      "\n",
      " step:100 ; loss:2.322658576777174\n",
      " train_acc:0.0;  val_acc:0.135\n",
      "\n",
      " step:200 ; loss:2.2560641373902453\n",
      " train_acc:0.0;  val_acc:0.15\n",
      "\n",
      " step:300 ; loss:2.1825470524006914\n",
      " train_acc:1.0;  val_acc:0.105\n",
      "\n",
      " step:400 ; loss:2.208445091755495\n",
      " train_acc:0.0;  val_acc:0.12\n",
      "\n",
      " step:500 ; loss:1.413758817626698\n",
      " train_acc:0.5;  val_acc:0.475\n",
      "\n",
      " step:600 ; loss:0.8138671602711395\n",
      " train_acc:1.0;  val_acc:0.605\n",
      "\n",
      " step:700 ; loss:0.040969240382020794\n",
      " train_acc:1.0;  val_acc:0.695\n",
      "\n",
      " step:800 ; loss:0.2943919590130214\n",
      " train_acc:1.0;  val_acc:0.8\n",
      "\n",
      " step:900 ; loss:0.7937038773889639\n",
      " train_acc:0.5;  val_acc:0.775\n",
      "\n",
      " step:1000 ; loss:0.20416262923266468\n",
      " train_acc:1.0;  val_acc:0.82\n",
      "\n",
      " step:1100 ; loss:3.492562642433139\n",
      " train_acc:0.5;  val_acc:0.755\n",
      "\n",
      " step:1200 ; loss:0.44327566847604044\n",
      " train_acc:1.0;  val_acc:0.81\n",
      "\n",
      " step:1300 ; loss:0.381620659555296\n",
      " train_acc:1.0;  val_acc:0.78\n",
      "\n",
      " step:1400 ; loss:0.1379428630137357\n",
      " train_acc:1.0;  val_acc:0.715\n",
      "\n",
      " step:1500 ; loss:0.0048211652445979145\n",
      " train_acc:1.0;  val_acc:0.78\n",
      "\n",
      " step:1600 ; loss:0.6156347089073209\n",
      " train_acc:1.0;  val_acc:0.78\n",
      "\n",
      " step:1700 ; loss:2.9270997739154003\n",
      " train_acc:0.5;  val_acc:0.84\n",
      "\n",
      " step:1800 ; loss:0.7148056981166203\n",
      " train_acc:1.0;  val_acc:0.845\n",
      "\n",
      " step:1900 ; loss:3.3810034206400825\n",
      " train_acc:0.5;  val_acc:0.745\n",
      "\n",
      " final result test_acc:0.8279;  val_acc:0.839\n"
     ]
    }
   ],
   "source": [
    "\n",
    "from nn.optimizers import SGD\n",
    "# 初始化变量\n",
    "batch_size=2\n",
    "steps = 2000\n",
    "\n",
    "# 更新梯度\n",
    "sgd=SGD(weights,lr=0.01,decay=1e-6)\n",
    "\n",
    "for s in range(steps):\n",
    "    X,y=next_batch(batch_size)\n",
    "\n",
    "    # 前向过程\n",
    "    forward(X)\n",
    "    # 反向过程\n",
    "    loss=backward(X,y)\n",
    "\n",
    "\n",
    "    sgd.iterate(weights,gradients)\n",
    "\n",
    "    if s % 100 ==0:\n",
    "        print(\"\\n step:{} ; loss:{}\".format(s,loss))\n",
    "        idx=np.random.choice(len(val_x),200)\n",
    "        print(\" train_acc:{};  val_acc:{}\".format(get_accuracy(X,y),get_accuracy(val_x[idx],val_y[idx])))\n",
    "\n",
    "print(\"\\n final result test_acc:{};  val_acc:{}\".\n",
    "      format(get_accuracy(test_x,test_y),get_accuracy(val_x,val_y)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAADF1JREFUeJzt3XuMVPUVB/DvYV1YXcQC8n6I0lVB\na9e6BRWsGqpiawNGtJK0JS31VWhsrI2ExmAaqqRRkFQFsW5BY121D8CE+sjWahuQ8hAFihRqEBcW\nVoSGVQPs4/SPuZusO2fOzsyde+fB95OQnTl7597fBL7cub+591xRVRCRrUe+B0BUyBgQIgcDQuRg\nQIgcDAiRgwEhcjAgRA4GhMjBgBA5TgnzYhGZDGAxgDIAv1PVBd7yPaWXVqAyzCaJcuIYPsMJPS7d\nLSfZnmoiImUA/gPgGgANADYAmK6q/071mj7ST8fLpKy2R5RL67UeR/VwtwEJ8xFrHIDdqvqBqp4A\nUAdgSoj1ERWcMAEZBuCjTs8bgtoXiMjtIrJRRDa24HiIzRHFL0xArN1T0uc1VV2mqjWqWlOOXiE2\nRxS/MAFpADCi0/PhAPaHGw5RYQkTkA0AqkTkbBHpCeBWAKtzMyyiwpD1NK+qtorIbACvIjHNW6uq\n23M2MqICEOp7EFVdA2BNjsZCVHD4TTqRgwEhcjAgRA4GhMjBgBA5GBAiBwNC5GBAiBwMCJGDASFy\nMCBEDgaEyMGAEDlCnc1LJw/pZV8NuvvBi816W2W7WR9ab/dJ6P3S+uwGFjHuQYgcDAiRgwEhcjAg\nRA4GhMjBWawS0qOiwq7372fW25oOJdW05YS57IE7LjHrf7/5N2Z9SNlpZr1m22yz3ltSdAHN812Y\nwzav3gOgGUAbgFZVrcnFoIgKRS72IFeravJ/RUQlgMcgRI6wAVEAr4nIJhG53VqAzaupmIX9iDVB\nVfeLyEAAr4vI+6r6VucFVHUZgGVA4v4gIbdHFKusb6CTtCKRBwB8qqoPp1qGN9DJjR6n2TNEjT+u\nNuub7nvMrI9dPiupNuqX6zIai17+VbNeW/e4WU81u3Xl3XeZ9co/RnOOVuQ30BGRShE5veMxgGsB\nbMt2fUSFKMxHrEEA/iKJ+etTAPxBVV/JyaiICkSY7u4fALD3r0QlgtO8RA4GhMjBc7EKWNnYc836\n54/a50ttusCerUql90WfZDymrmTtu2Z97bGk+7kCAG6qPGLWj44sM+uV2Q0rZ7gHIXIwIEQOBoTI\nwYAQORgQIgdnsQpAWdU5Zn3n7faVgM98+YmM1v/y533Muqzsn9F6MjHnr9PN+k3T7LE/89NFZv2+\nheNzNqZscA9C5GBAiBwMCJGDASFyMCBEDs5ixejEdXZXpHlP1Jr1KypaM1r/rw59xaxvvG64We9/\nILOrBzMx59rVGS0/be2dZn003snFcLLGPQiRgwEhcjAgRA4GhMjBgBA5up3FEpFaADcAaFLVC4Na\nPwAvABgFYA+AW1TVvlSsxJVdcJ5ZP3JR36Ta/PlPmcummq3a2/q5WX/6yGVm/ZXHJ5r1KGerUnno\nzRvM+szvLDXrbcftKwrzLZ09yHIAk7vU5gCoV9UqAPXBc6KS021Aglaih7uUpwBYETxeAWBqjsdF\nVBCyPQYZpKqNABD8HJhqQTavpmIW+UG6qi5T1RpVrSmHfa9tokKV7akmB0VkiKo2isgQAE25HFQx\n+Wi+fXC5ZVxmFzVZpr5zm1kfPHWHWe+P+A/GU7lzwhsZLX/+wk/NensuBhNCtnuQ1QBmBI9nAFiV\nm+EQFZZuAyIizwNYB+A8EWkQkZkAFgC4RkR2AbgmeE5Ucrr9iKWq9sXFAG/0QSWP36QTORgQIgcv\nmOoi1e3N3l90gVn//UVPp73uVKeOpJ
qtGn6XffZOZpdR5cfSf11p1u+dvNOsN1xvtzgamud7lnEP\nQuRgQIgcDAiRgwEhcjAgRI6TdharR0WFWd+59HyzvnvSkxmtf8NxTardM/dec9nBdW+b9WKYrWr5\n5iVmff21i1O84lSzOvy3m816sZ6LRXRSYECIHAwIkYMBIXIwIESOkp/FkvKeZv3Aj75m1ndNeiyj\n9e9rs8+v+uGmu5Jq57y5x1y2GGarysZUmfXd3y43673Frtd9OsCsa1u+56ts3IMQORgQIgcDQuRg\nQIgcDAiRI9vm1Q8AuA3Ax8Fic1V1TVSDDGPP/fa5QttnZjZb1ZRitmrqg78w6yOXJveoKubZqqtf\nss+VernvLrP+XPNQs153k93rQ1vsKw3zLdvm1QCwSFWrgz8FGQ6isLJtXk10UghzDDJbRN4TkVoR\nSb4ZRoDNq6mYZRuQJQBGA6gG0AjgkVQLsnk1FbOsAqKqB1W1TVXbATwFYFxuh0VUGLI6F6ujs3vw\n9EYAee5eBLRPrDbrwy7fl9F67m+y1/Pq0glmfcCThdNRPZWyc0eb9QOTkm/rMn3Wa+ay96SYrXq2\nebBZX7RkmlkfvH2tWS9U6UzzPg/gKgBnikgDgHkArhKRagCKxD0K74hwjER5k23z6vTbCRIVMX6T\nTuRgQIgcDAiRo2SuKPzvLfZ3LLvGrDTrqc6tevs+e8b6zNcKf7Zq77zLzfqSHyw161dUpH922GVb\nvmvWB/zkmFkf/GFxzValwj0IkYMBIXIwIEQOBoTIUTIH6ede2JDR8uUiZr2lssxePuMRJTtl+DCz\nvnf6WRmt59Jp75r1uqEPm/U+PexG3RPfuzmp9r91g8xlz3l6j1lv3bffrJcK7kGIHAwIkYMBIXIw\nIEQOBoTIIarJtwqLSh/pp+PFbvsS1te3tJn1+QO3mvUWtZdfd9yexcqFSjlh1qt75mYyceVnXzLr\n9z/7PbM+csHGpJq22GMsNeu1Hkf1sD2V2Qn3IEQOBoTIwYAQORgQIgcDQuRIp6vJCADPABiMxH3d\nl6nqYhHpB+AFAKOQ6Gxyi6oeiW6ovr89ZLflaVton7fUA/YExoReUd4KLLPZqm9stVvnnPbQGWa9\nfNuHZn3EJ/bFS/HNXxavdPYgrQB+rqpjAFwKYJaIjAUwB0C9qlYBqA+eE5WUdJpXN6rq5uBxM4Ad\nAIYBmAJgRbDYCgBToxokUb5kdAwiIqMAXAxgPYBBHd0Vg5/JbfrA5tVU3NIOiIj0BvAnAD9T1aPp\nvo7Nq6mYpRUQESlHIhzPqeqfg/JBERkS/H4IgKZohkiUP+nMYgkSrUZ3qOrCTr9aDWAGgAXBz1WR\njDBNZ6zaYtYvO3WWWf94vH0uVq70PJR8Ttfohe9ntI7Tm/eadW212/VE+45OTunMO04A8H0AW0Wk\n41/hXCSC8aKIzASwF0Dy9ZtERS6d5tX/BFJ8aQBEc2ouUYHgN+lEDgaEyMGAEDlKpi9W+zG7iXLf\n5XbT6b7LIxxMCpxlKj7cgxA5GBAiBwNC5GBAiBwMCJGDASFyMCBEDgaEyMGAEDkYECIHA0LkYECI\nHAwIkYMBIXIwIEQOBoTI0W1ARGSEiLwhIjtEZLuI3B3UHxCRfSKyJfjzreiHSxSvdK4o7GhevVlE\nTgewSUReD363SFUfjm54RPmVTtufRgAdPXibRaSjeTVRyQvTvBoAZovIeyJSKyJ9U7yGzaupaIVp\nXr0EwGgA1UjsYR6xXsfm1VTMsm5eraoHVbVNVdsBPAVgXHTDJMqPdGaxzObVHZ3dAzcC2Jb74RHl\nV5jm1dNFpBqJW93tAXBHJCMkyqMwzavX5H44RIWF36QTORgQIgcDQuRgQIgcDAiRgwEhcjAgRA4G\nhMjBgBA5RFXj25jIxwA+DJ6eCeBQbBvPH77PwnSWqg7obqFYA/KFDYtsVNWavGw8RnyfxY0fsYgc\nDA
iRI58BWZbHbceJ77OI5e0YhKgY8CMWkYMBIXLEHhARmSwiO0Vkt4jMiXv7UQraHzWJyLZOtX4i\n8rqI7Ap+mu2RionTbbPk3musARGRMgCPA7gewFgkrmsfG+cYIrYcwOQutTkA6lW1CkB98LzYdXTb\nHAPgUgCzgr/Hknuvce9BxgHYraofqOoJAHUApsQ8hsio6lsADncpTwGwIni8AsDUWAcVAVVtVNXN\nweNmAB3dNkvuvcYdkGEAPur0vAGl38Z0UNC+taON68A8jyenunTbLLn3GndArO4onGcuUka3zZIT\nd0AaAIzo9Hw4gP0xjyFuBzua7AU/m/I8npywum2iBN9r3AHZAKBKRM4WkZ4AbgWwOuYxxG01gBnB\n4xkAVuVxLDmRqtsmSvG9xv1NenCjnUcBlAGoVdVfxzqACInI8wCuQuLU74MA5gFYCeBFACMB7AVw\ns6p2PZAvKiIyEcA/AGwF0B6U5yJxHFJa75WnmhClxm/SiRwMCJGDASFyMCBEDgaEyMGAEDkYECLH\n/wF6bHULzFYZBAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fd3269bd4a8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y_true:0,y_predict:0\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAC95JREFUeJzt3X+MFPUZx/H3ww/RUEAPBREOfwCt\nkprSgmBKo1hiC00j2FQDppY0KLaV1lrThviP0MRoTK01rVVPIdDEoiatQo1pNVcbamoQUAsKVSmC\nHhBORQo1ih739I+bMyf3ne8tu7Ozs8vnlZDdfW5257tcPjc73515xtwdEQnrV+sBiBSZAiISoYCI\nRCggIhEKiEiEAiISoYCIRCggIhEKiEjEgEqebGazgLuB/sCD7n57bPkTbJCfyOBKVimSiQ95n4/8\nsPW1nJV7qImZ9QdeAy4F2oANwHx335r2nKHW5NNsZlnrE8nSem/loO/vMyCVfMSaCmx39x3u/hHw\nMDCngtcTKZxKAjIaeKvH47ak9ilmtsjMNprZxo85XMHqRPJXSUBCm6den9fcvcXdp7j7lIEMqmB1\nIvmrJCBtQHOPx2OAPZUNR6RYKgnIBmCCmZ1tZicA84C12QxLpBjKnuZ19w4zWwz8la5p3hXu/kpm\nIxMpgIq+B3H3J4EnMxqLSOHom3SRCAVEJEIBEYlQQEQiFBCRCAVEJEIBEYlQQEQiFBCRCAVEJEIB\nEYlQQEQiFBCRCAVEJEIBEYlQQEQiFBCRCAVEJEIBEYmo6Jx0SXfZ1nd71RYN2xlcduYPfxCsn7Tm\n+SyHVBUdMycH67sWHgnWx131UjWHk7lKm1fvBA4BR4AOd5+SxaBEiiKLLcgl7v5OBq8jUjjaBxGJ\nqDQgDjxlZpvMbFFoATWvlnpW6Ues6e6+x8xGAE+b2b/dfV3PBdy9BWiBruuDVLg+kVxV2llxT3Lb\nbmaP0XXNkHXxZzWYqecHy4uGrexV66QzuOz+c8O/htFryh5VbtJmq165+IFgffrCHwfrw5c/l9mY\nslT2RywzG2xmQ7rvA18DXs5qYCJFUMkWZCTwmJl1v84f3P0vmYxKpCAq6e6+A/hChmMRKRxN84pE\nKCAiEToWq0T9Tx4WrB/8xfvBer/gJRzDf49GP3Oo3GHV3I2TWoP1finv9cNT+7zycqFoCyISoYCI\nRCggIhEKiEiEAiISoVmsEh3+0vhgvfX8+4L1zsDfnos2XxlcdujzW8ofWF6O4ZgzSD/u7MyHdgXr\nHWUNqvq0BRGJUEBEIhQQkQgFRCRCARGJ0CzWUQY0jwnWL/vtU8F62jFH+4580Kt2yemvB5fdVAd/\np3ZfMiRYH2j9g/WPU06u7mjbndWQclH834xIDSkgIhEKiEiEAiISoYCIRPQ5i2VmK4BvAu3u/vmk\n1gQ8ApwF7ASudPf3qjfM/Gy95fRg/fFh4SZVacccXb70Z71qw7f8L2WtxTkWK3UWb96zwfrHHu6L\ndc+BcZmNqZZK2YKsBGYdVVsCtLr7BKA1eSzScPoMSNJKdP9R5TnAquT+KmBuxuMSKYRy90FGuvte\ngOR2RNqCal4t9azqO+nu3uLuU9x9ykAGVXt1Ipkq91CTfWY2yt33mtkooD3LQeVh17IvB+uvzf5N\nsB46dATCO+MATSt6N2Ouh9b2/7lmbLD++IjwJEXa/8tDd84O1psoZpPqNOVuQdYCC5L7C4A66EMu\ncuz6DIiZrQaeAz5nZm1mthC4HbjUzF4HLk0eizScPj9iufv8lB/NzHgsIoWjb9JFIhQQkYjGP2Eq\npV3NyqvDs1Vph47M3hS8RilnBGar6tn0r28O1tP+X373bng2MDSLV4+0BRGJUEBEIhQQkQgFRCRC\nARGJaPhZrMn3/ytYv2BQ+FJgGw6H/2
accVu4vU09+2DO1F61lubSm3EDrN7U+zUAPsvG8gdWINqC\niEQoICIRCohIhAIiEqGAiEQ0zixWyjFX3x9+b7DeyUnB+v3tM8KvXw+XSTtGJ//0zV61zpTzHtOO\nxZq4dG+wXtRLqh0rbUFEIhQQkQgFRCRCARGJUEBEIsptXr0UuBZ4O1nsZnd/slqDLEnKLNPiN74d\nrD82Pjzclua/B+v9doeP3bpoS/j1T7zjlGA9ZMDfNpW8bDkOzr8wWL+r+Z5etX6E3+d9B8YH66/e\ncVqwfvE5/w3W160Lzzae8/NinoFYbvNqgLvcfVLyr7bhEKmScptXixwXKtkHWWxmm81shZmlfp5Q\n82qpZ+UG5F5gHDAJ2AvcmbagmldLPSsrIO6+z92PuHsn8AAQPmtGpM6Ze989x83sLOCJHrNYo7qv\nD2JmNwLT3H1eX68z1Jp8muXbsXTAmNHB+o5rzgzWH/zub4P1qYOO7RilfoG/PWnLLmufHKxnZdmI\nF4P10HhC405bFuDFlDMwv7fqR8H62GX/DNbztt5bOej7w1N2PZQyzbsamAGcamZtwC3ADDObRFdH\n/53AdRWNVqSgym1evbwKYxEpHH2TLhKhgIhEKCAiESXNYmWlFrNYWQn1kALYf254N+6yec9WczhB\nkwfvDNbnDj4QrIfOHpy84TvBZU9ZPiRYP/HPz5c2uIIpdRZLWxCRCAVEJEIBEYlQQEQitJPeQF57\ncEq4Pvv+YD10+Mi3ps0NLtvRtrv8gRWQdtJFMqCAiEQoICIRCohIhAIiEtE4zauPIx1fDZ9glTZb\nldbKZ+ptN/SqjWgrxglNRaEtiEiEAiISoYCIRCggIhEKiEhEKV1NmoHfA6cDnUCLu99tZk3AI8BZ\ndHU2udLd36veUKXbG5eHf21prXk2pbTmGdX6dq/akfKH1ZBK2YJ0ADe5+3nAhcD1ZjYRWAK0uvsE\noDV5LNJQSmlevdfdX0juHwK2AaOBOcCqZLFVQPgwUJE6dkz7IEmHxS8C64GR3d0Vk9sRKc9R82qp\nWyUHxMw+A/wR+Im7Hyz1eWpeLfWspICY2UC6wvGQu/8pKe8zs1HJz0cB7dUZokjtlDKLZXS1Gt3m\n7r/q8aO1wALg9uR2TVVGeBwb0DwmWL9pZviCXmmNpxffujhYH76tmJc9K5JSDlacDlwNbDGzl5La\nzXQF41EzWwi8CVxRnSGK1E4pzaufhZTDQUEnmEtD0zfpIhEKiEiEAiISoTMKC2zXVWOD9UXDwhOG\nacdiDV+u2apyaQsiEqGAiEQoICIRCohIhAIiEqFZrDqU1ufqvgPjcx5J49MWRCRCARGJUEBEIhQQ\nkQgFRCRCs1gFNnxrR7DeSfi6kotO3h6sP8EFmY3peKMtiEiEAiISoYCIRCggIhHmHt7h+2SB9ObV\nS4Frge4OyDe7e7gfTWKoNfk0U58Hqb313spB35/WjOQTpcxidTevfsHMhgCbzOzp5Gd3ufsvKxmo\nSJGV0vZnL9Ddg/eQmXU3rxZpeJU0rwZYbGabzWyFmZ2S8hw1r5a6VUnz6nuBccAkurYwd4aep+bV\nUs/Kbl7t7vvc/Yi7dwIPAFOrN0yR2ugzIGnNq7s7uycuB17OfngitVVJ8+r5ZjYJcLquUXhdVUYo\nUkOVNK+Ofuch0gj0TbpIhAIiEqGAiEQoICIRCohIhAIiEqGAiEQoICIRCohIRJ9nFGa6MrO3gV3J\nw1OBd3Jbee3ofRbTme5+Wl8L5RqQT63YbKO7T6nJynOk91nf9BFLJEIBEYmoZUBaarjuPOl91rGa\n7YOI1AN9xBKJUEBEInIPiJnNMrNXzWy7mS3Je/3VlLQ/ajezl3vUmszsaTN7PbkNtkeqJ2bWbGbP\nmNk2M3vFzG5I6g33XnMNiJn1B+4BZgMT6TqvfWKeY6iylcCso2pLgFZ3nwC0Jo/rXXe3zfOAC4Hr\nk9
9jw73XvLcgU4Ht7r7D3T8CHgbm5DyGqnH3dcD+o8pzgFXJ/VXA3FwHVQXuvtfdX0juHwK6u202\n3HvNOyCjgbd6PG6j8duYjkzat3a3cR1R4/Fk6qhumw33XvMOSKg7iuaZ61Sg22bDyTsgbUBzj8dj\ngD05jyFv+7qb7CW37TUeTyZC3TZpwPead0A2ABPM7GwzOwGYB6zNeQx5WwssSO4vANbUcCyZSOu2\nSSO+17y/STezbwC/BvoDK9z91lwHUEVmthqYQdeh3/uAW4DHgUeBscCbwBXufvSOfF0xs68A/wC2\n0HVRJejqtrmeRnuvOtREJJ2+SReJUEBEIhQQkQgFRCRCARGJUEBEIhQQkYj/A+DydaCAaM+NAAAA\nAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fd326916390>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y_true:4,y_predict:4\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAADFVJREFUeJzt3X9sVfUZBvDnbQURJoZOwQrMoiuO\nRmfNanVRN5SwIRCLMzL5w2HiFBdQF9wSRrLpH3MxmYouYWiZaI0OMVMnEhRdN1edC4LEIcoPsaJW\nGhCr0DAR2vvuj3u6VO573t7ec+6vw/NJSO99Ofec70Wfnnu+95z3iKqCiGwVxR4AUSljQIgcDAiR\ngwEhcjAgRA4GhMjBgBA5GBAiBwNC5DguyotFZBqA+wFUAviTqt7lLT9UjtdhGBFlk0SxOISDOKxf\nykDLSa6nmohIJYAdAKYC6ACwAcAcVX0n7DUjpUovkCk5bY8oTuu1FQe0a8CARPmI1Qhgp6q2q+ph\nAE8AaIqwPqKSEyUgYwF81O95R1D7ChG5UUQ2isjGI/gywuaICi9KQKzdU8bnNVVtVtUGVW0YguMj\nbI6o8KIEpAPA+H7PxwHYHW04RKUlSkA2AKgVkQkiMhTANQBWxzMsotKQ8zSvqvaIyAIA65Ce5l2h\nqm/HNjKiEhDpexBVXQtgbUxjISo5/CadyMGAEDkYECIHA0LkYECIHAwIkYMBIXIwIEQOBoTIwYAQ\nORgQIgcDQuRgQIgckc7mpcGprJto1vU4+/dU++xRZv1w9RGzfl3Da2Z98clvZTG6tDntPzTr/72i\n16z3fvZZ1usuR9yDEDkYECIHA0LkYECIHAwIkYOzWHmy44HGjNr6GUvMZU+qGBrLNitCft+lkMp6\nHY+f8bxZb3z4J2b91FnJnsWK2rx6F4BuAL0AelS1IY5BEZWKOPYgl6rqvhjWQ1RyeAxC5IgaEAXw\nooi8ISI3WguweTWVs6gfsS5S1d0iMhrASyKyTVXb+i+gqs0AmoH0/UEibo+ooKJ2Vtwd/NwrIs8g\nfc+QNv9VyZK6uN6s/3P6vRm1kyry293+/Z5DZn3Gv+ZnvY5hJxw2698f955ZfzfrNZennD9iicgI\nETmx7zGAHwDYEtfAiEpBlD3IGADPiEjfev6sqi/EMiqiEhGlu3s7gHNjHAtRyeE0L5GDASFy8Fys\niDoX2lf3janMfsaq/Yi9jhltC8z6pNs/NeupfV1m/YzuN7MeC8S+M/J7w4eHvCDZ321xD0LkYECI\nHAwIkYMBIXIwIEQOzmJFNKTS7hdl6eq1Z3wWXvFTs167eZNZ78l6izlQ+3xSGWbPylUOta+GTEq/\nLO5BiBwMCJGDASFyMCBEDgaEyMFZrAI6scL+5w7r4l6zOZ+jsR2edr5Zn3733836H1+7zKxPvHFD\nbGMqJu5BiBwMCJGDASFyMCBEDgaEyDHgLJaIrAAwE8BeVT07qFUBWAWgBsAuALNVNRkn3wzSoU1V\nZr3rvMzzrqpCrjJ8/brMHloA0CgLzXrNr1+3B5OyzwurGDHCrG9f+q2MWuuldgf6Uyrt/1XWPDfF\nHktCZLMHeQTAtKNqiwC0qmotgNbgOVHiDBiQoJXo0Rc7NwFoCR63AJgV87iISkKuxyBjVLUTAIKf\no8MWZPNqKmd5P0hX1WZVbVDVhiHIb29aoriJhlwg85WFRGoArOl3kL4dwGRV7RSRagAvq+pZA61n\npFTpBZLsg7o+FedOyqid/fA2c9nfjgk56A5xzqO3mPXTXrUvpbp5ySqzPnOE3T7IXHbbj8x6xZSP\nsl5HKVmvrTigXXaPo35y3YOsBjA3eDwXwLM5roeopA0YEBFZCeDfAM4SkQ4RuR7AXQCmisi7AKYG\nz4kSZ8DvQVR1TshfHRufleiYxm/SiR
wMCJEjq1msuBxLs1iWsFM+Rr1ot855uObFwa0/5PddCqms\n13FOW0gLol/ZZxL17Pow63WXknzPYhEdExgQIgcDQuRgQIgcDAiRg21/Cih18KBZ33/tKWb9+XV2\nO6DLh8dzbdp9XXUZtdpf7jOX7en4OJZtlhvuQYgcDAiRgwEhcjAgRA4GhMjBWawC+vSG75r1H99q\nn3M1Y/j+kDXZv9eGSKVZPxJyul2FGOdoyYCnJx1TuAchcjAgRA4GhMjBgBA5GBAiR67Nq+8AcAOA\nT4LFFqvq2nwNshx90dSYUVv7m7vNZU+qsK8ozP46wLQJz91k1h+b+qBZv2VUZp+uB373PXPZb17b\nMcjRJEOuzasBYImq1gd/GA5KpFybVxMdE6IcgywQkc0iskJE7POywebVVN5yDcgyAGcCqAfQCeCe\nsAXZvJrKWU4BUdU9qtqrqikAywFkHpESJUBO52KJSHXf/UEAXAlgS3xDSobusZn/tGGzVYP1nfXX\nmfWJ8zaY9TvrrjHrVz3VllFbc8lSc9mF515v1lP/2WrWkyKbad6VACYDOFlEOgDcDmCyiNQDUKTv\nUTgvj2MkKppcm1c/lIexEJUcfpNO5GBAiBwMCJGDVxTmyf7G6F+KLvu81qyffsvnZt2+QyHQ+84O\ns/6H5sz7Ds657T5z2dqHdpr1HRfZ323pl8n4Uph7ECIHA0LkYECIHAwIkYMH6Xky8+zNkdex/LHp\nZn1cx2uR1w0Apy7JXM+T88aZy/6+2t7mtMk/M+tD123MfWAlhHsQIgcDQuRgQIgcDAiRgwEhcnAW\nqwRsP9Jr1se+bN+yLZ8efP8Ssz7n20+Y9V1X2c2uJ66LbUhFxT0IkYMBIXIwIEQOBoTIwYAQObLp\najIewKMATkW6n3Kzqt4vIlUAVgGoQbqzyWxVjecO9wnw8srzM4sL7fOZPukdYdYrDtmXQFWMHGlv\ntNK+BVuYbUvOyKhtOWdZyNL2bNUJVV8MapvlJps9SA+A21R1EoALAcwXkToAiwC0qmotgNbgOVGi\nZNO8ulNVNwWPuwFsBTAWQBOAlmCxFgCz8jVIomIZ1DGIiNQAOA/AegBj+rorBj9Hh7yGzaupbGUd\nEBH5GoCnAPxcVQ9k+zo2r6ZyllVARGQI0uF4XFWfDsp7RKQ6+PtqAHvzM0Si4slmFkuQbjW6VVXv\n7fdXqwHMBXBX8PPZvIywTI39W+aEXvvNR8xlLx5mr+PTVWvMev3xu836hOPsFaUGdTM3e7ZqVXe1\nvc3b9pv1sBZE5SabkxUvAnAtgLdE5M2gthjpYDwpItcD+BDA1fkZIlHxZNO8+lWE/VoBpsQ7HKLS\nwm/SiRwMCJGDASFy8IrCPLFuTTbzlfnmsu9c1mzWm0bsC1l7PLdys4Q1zG5ZavfoGv1BPD26ShX3\nIEQOBoTIwYAQORgQIgcDQuTgLFYBnfa0Pfs0/ZGbzHpqkT2L9ULdXwa13TUHv27Wf9E2O6NWd0en\nuezomDrKlxvuQYgcDAiRgwEhcjAgRA4GhMghqlqwjY2UKr1AeAkJFd96bcUB7Qq7zun/uAchcjAg\nRA4GhMjBgBA5BgyIiIwXkX+IyFYReVtEbg3qd4jIxyLyZvDHvqKGqIxlcy5WX/PqTSJyIoA3ROSl\n4O+WqOrd+RseUXFl0/anE0BfD95uEelrXk2UeFGaVwPAAhHZLCIrRGRUyGvYvJrKVpTm1csAnAmg\nHuk9zD3W69i8mspZzs2rVXWPqvaqagrAcgCN+RsmUXFkM4tlNq/u6+weuBLAlviHR1RcUZpXzxGR\negCK9D0K5+VlhERFFKV59dr4h0NUWvhNOpGDASFyMCBEDgaEyMGAEDkYECIHA0LkYECIHAwIkaOg\nbX9E5BMAHwRPTwYQdo+xJOH7LE2nq+opAy1U0IB8ZcMiG1W1oSgbLyC+z/LGj1hEDgaEyFHMgNj3\nPk
4evs8yVrRjEKJywI9YRA4GhMhR8ICIyDQR2S4iO0VkUaG3n09B+6O9IrKlX61KRF4SkXeDn2Z7\npHLidNtM3HstaEBEpBLAUgCXA6hD+rr2ukKOIc8eATDtqNoiAK2qWgugNXhe7vq6bU4CcCGA+cF/\nx8S910LvQRoB7FTVdlU9DOAJAE0FHkPeqGobgK6jyk0AWoLHLQBmFXRQeaCqnaq6KXjcDaCv22bi\n3muhAzIWwEf9nncg+W1MxwTtW/vauI4u8nhidVS3zcS910IHxOqOwnnmMmV020ycQgekA8D4fs/H\nAdhd4DEU2p6+JnvBz71FHk8srG6bSOB7LXRANgCoFZEJIjIUwDUAVhd4DIW2GsDc4PFcAM8WcSyx\nCOu2iSS+10J/kx7caOc+AJUAVqjqnQUdQB6JyEoAk5E+9XsPgNsB/BXAkwC+AeBDAFer6tEH8mVF\nRC4G8AqAtwCkgvJipI9DkvVeeaoJUTh+k07kYECIHAwIkYMBIXIwIEQOBoTIwYAQOf4HIaN37lb6\ncNAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fd3268e4710>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y_true:8,y_predict:8\n"
     ]
    }
   ],
   "source": [
    "# 随机查看预测结果\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "idx=np.random.choice(test_x.shape[0],3)\n",
    "x,y=test_x[idx],test_y[idx]\n",
    "y_predict = forward(x)\n",
    "for i in range(3):\n",
    "    plt.figure(figsize=(3,3))\n",
    "    plt.imshow(np.reshape(x[i],(28,28)))\n",
    "    plt.show()\n",
    "    print(\"y_true:{},y_predict:{}\".format(np.argmax(y[i]),np.argmax(y_predict[i])))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
